#!/usr/bin/env python3
"""
Run the Volume-4 Wilson-loop pipeline lattice sweep inside the FPHS monorepo.

Usage:
  python vol4_wilson_loop_pipeline_lattice_sweep/run.py \
    --config configs/default.yaml \
    --output-dir data/results/vol4_wilson_loop_pipeline_lattice_sweep
"""

from __future__ import annotations
import argparse
import os
import sys
import glob
from typing import Dict, Any, Tuple
import numpy as np

# Ensure repo root on sys.path (parent of this script's directory)
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

from sim_utils import load_config, sweep_iter, seed_all, save_csv
from orig.simulation.build_lattice import build_lattice
from orig.simulation.compute_Amu import logistic_D, linear_gD
from orig.simulation.compute_Umu import compute_U_from_A
from orig.simulation.measure_wilson import measure_wilson_loops
from orig.simulation.plot_results import fit_string_tension


def _repo_root_from_config(config_path: str) -> str:
    """
    Return the absolute directory two levels above the config file.

    The config is taken to live one directory below the repo root
    (e.g. ``<root>/configs/default.yaml``), so the parent of the config's
    own directory is returned.
    """
    config_file = os.path.abspath(config_path)
    config_dir, _filename = os.path.split(config_file)
    return os.path.dirname(config_dir)


def _abs(repo_root: str, p: str) -> str:
    """Return ``p`` unchanged if it is absolute, else join it onto ``repo_root``."""
    if os.path.isabs(p):
        return p
    return os.path.join(repo_root, p)


def resolve_flip_counts_path(cfg: Dict[str, Any], L: int, repo_root: str) -> str:
    """
    Locate the flip-counts file for lattice label ``L``.

    Candidates are tried in this order and the first existing file wins:
      1) ``cfg["flip_counts_path_template"]`` formatted with ``L`` and the
         top-level ``b``, ``k``, ``n0`` config values (``None`` when absent);
      2) auto-discovery of
         ``<repo_root>/data/results/vol4_loop_fluctuation_sim/**/L{L}/flip_counts.npy``
         (first hit in sorted order, so the choice is deterministic);
      3) the single global ``cfg["flip_counts_path"]``.

    Relative paths are resolved against ``repo_root``.

    Raises:
        FileNotFoundError: if none of the candidates exists.
    """
    # 1) Per-L template at top level
    tmpl = cfg.get("flip_counts_path_template")
    if tmpl:
        fmt_args = {key: cfg.get(key) for key in ("b", "k", "n0")}
        fc_path = _abs(repo_root, tmpl.format(L=L, **fmt_args))
        if os.path.exists(fc_path):
            return fc_path

    # 2) Auto-discover common layout: .../**/L{L}/flip_counts.npy
    #    This must run whether or not a template is configured; the error
    #    message below reports it as always searched.
    candidate_glob = os.path.join(
        repo_root, "data", "results", "vol4_loop_fluctuation_sim", "**", f"L{L}", "flip_counts.npy"
    )
    # glob order is filesystem-dependent; sort so repeated runs agree.
    hits = sorted(glob.glob(candidate_glob, recursive=True))
    if hits:
        return hits[0]

    # 3) Single global path fallback
    single = cfg.get("flip_counts_path")
    if single:
        fc_path = _abs(repo_root, single)
        if os.path.exists(fc_path):
            return fc_path

    raise FileNotFoundError(
        f"flip_counts not found for L={L}.\n"
        f"  Tried template: {tmpl!r}\n"
        f"  Tried single  : {single!r}\n"
        f"  Also searched : data/results/vol4_loop_fluctuation_sim/**/L{L}/flip_counts.npy"
    )


def _kernel_from_template(template: Any, gauge: str, L: int, repo_root: str) -> "str | None":
    """Return the existing kernel file named by ``template``, or ``None``.

    ``template`` may be a format string using ``{gauge}`` and ``{L}``, or a
    dict mapping gauge name to a format string using ``{L}``. Any other
    value (including a missing per-gauge entry) yields ``None``.
    """
    if isinstance(template, str):
        rel = template.format(gauge=gauge, L=L)
    elif isinstance(template, dict):
        per_gauge = template.get(gauge)
        if not per_gauge:
            return None
        rel = per_gauge.format(L=L)
    else:
        return None
    path = _abs(repo_root, rel)
    return path if os.path.exists(path) else None


def resolve_kernel_path(cfg: Dict[str, Any], cross_cfg: Dict[str, Any],
                        gauge: str, L: int, repo_root: str) -> str:
    """
    Resolve the kernel path in this priority:
      1) crossover_analysis.kernel_path_template (string: may use {gauge},{L})
         or dict (per-gauge template with {L})
      2) top-level kernel_path_template (dict per gauge or string format)
      3) top-level kernel_paths (single file per gauge; not per-L)

    Relative paths are resolved against ``repo_root``; the first candidate
    that exists on disk is returned.

    Raises:
        FileNotFoundError: if nothing resolves to an existing file.
    """
    # 1) crossover_analysis.kernel_path_template
    kpt = cross_cfg.get("kernel_path_template")
    found = _kernel_from_template(kpt, gauge, L, repo_root)
    if found is not None:
        return found

    # 2) top-level kernel_path_template
    kpt_top = cfg.get("kernel_path_template")
    found = _kernel_from_template(kpt_top, gauge, L, repo_root)
    if found is not None:
        return found

    # 3) top-level kernel_paths (no L)
    kp = cfg.get("kernel_paths", {})
    # kernel_paths may be present but not a dict (e.g. YAML null); look the
    # gauge up only when it is, so the error below can always be built.
    static_p = kp.get(gauge) if isinstance(kp, dict) else None
    if static_p:
        p = _abs(repo_root, static_p)
        if os.path.exists(p):
            return p

    raise FileNotFoundError(
        f"kernel not found for gauge={gauge}, L={L}.\n"
        f"  Tried crossover_analysis.kernel_path_template={kpt!r}\n"
        f"  Tried top-level kernel_path_template={kpt_top!r}\n"
        f"  Tried top-level kernel_paths[{gauge!r}]={static_p!r}"
    )


def run_lattice_sweep(cfg: dict, b: float, k: float, n0: float, L: int,
                      lattice_size: int, output_dir: str, config_path: str) -> None:
    """
    Run one sweep point and hand one summary row per gauge group to ``save_csv``.

    For the given lattice size this builds the lattice, loads the flip counts
    for ``L``, derives per-link pivot weights, and for every gauge group in
    ``crossover_analysis.gauge_groups`` loads the kernel, computes the link
    variables, measures Wilson loops and fits the string tension.

    Args:
        cfg: Parsed configuration mapping.
        b, k, n0: Sweep parameters; here they are only recorded in the output row.
        L: Lattice label used to resolve flip-count and kernel paths.
        lattice_size: Size passed to ``build_lattice``.
        output_dir: Directory that receives ``string_tension_summary.csv``.
        config_path: Path of the config file; the repo root is derived from it.

    Raises:
        FileNotFoundError: if flip counts or a kernel cannot be located.
        ValueError: if ``crossover_analysis.loop_sizes`` is empty.
        RuntimeError: if no gauge group was processed.
    """
    repo_root = _repo_root_from_config(config_path)

    # Precompute lattice and sizes
    cross_cfg = cfg.get("crossover_analysis", {})
    bc = cross_cfg.get("boundary_conditions", "periodic")
    lattice = build_lattice(lattice_size, boundary=bc)
    N_links = len(lattice)

    # Flip counts (required)
    fc_path = resolve_flip_counts_path(cfg, L=L, repo_root=repo_root)
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
    # files from trusted sources, or drop the flag if the data is purely numeric.
    flip_counts_raw = np.load(fc_path, allow_pickle=True)
    # np.resize repeats or truncates the flattened data to exactly N_links entries.
    flip_counts = np.resize(flip_counts_raw, N_links)

    # Pivot parameters (from crossover_analysis.pivot)
    pivot_cfg = cross_cfg.get("pivot", {})
    a_p      = pivot_cfg.get("a", 1.0)
    b_p      = pivot_cfg.get("b", 0.0)
    log_k    = pivot_cfg.get("logistic_k", 1.0)
    log_n0   = pivot_cfg.get("logistic_n0", 0.0)

    # Wilson-loop set-up
    gauge_groups = cross_cfg.get("gauge_groups", ["U1"])
    loop_sizes   = cross_cfg.get("loop_sizes", [])
    if not loop_sizes:
        raise ValueError("crossover_analysis.loop_sizes is empty – nothing to measure.")

    g_coupling = cfg.get("g", 1.0)

    # Pivot weights along links (gauge-independent, so computed once)
    D_vals  = logistic_D(flip_counts, log_k, log_n0)
    gD_vals = linear_gD(D_vals, a_p, b_p)
    link_weights = g_coupling * gD_vals

    out_path = os.path.join(output_dir, "string_tension_summary.csv")

    planned_gauges = 0
    for gauge in gauge_groups:
        kpath = resolve_kernel_path(cfg, cross_cfg, gauge, L, repo_root)

        # NOTE(review): same allow_pickle caveat as for the flip counts above.
        kernel_raw = np.load(kpath, allow_pickle=True)
        # Resize only the leading (link) axis; trailing axes keep their shape.
        # For 1-D input shape[1:] is empty, so this is just (N_links,).
        kernel = np.resize(kernel_raw, (N_links,) + kernel_raw.shape[1:])

        # A_mu: scale each link's kernel entry by its weight. Add one
        # singleton axis per extra kernel axis so the weights broadcast along
        # the link axis for any kernel rank (identical to the plain product
        # for 1-D kernels and to [:, None, None] for 3-D kernels).
        expand = (slice(None),) + (None,) * (kernel.ndim - 1)
        A = link_weights[expand] * kernel

        # U_mu and loops
        U = compute_U_from_A(A, gauge_group=gauge)
        loops = measure_wilson_loops(lattice, U, loop_sizes, bc=bc)

        # Fit area-law string tension
        sigma, ci95 = fit_string_tension(loops)

        row = {
            "b": b, "k": k, "n0": n0, "L": L,
            "lattice_size": lattice_size,
            "gauge_group": gauge,
            "string_tension": sigma,
            "ci95": ci95,
        }
        save_csv(out_path, row)
        planned_gauges += 1

    if planned_gauges == 0:
        raise RuntimeError("Planned 0 gauge jobs – check gauge_groups, kernels, and loop_sizes.")


def main() -> None:
    """
    CLI entry point for the lattice sweep.

    Parses ``--config`` and ``--output-dir``, loads the configuration, prints a
    preflight summary, then calls ``run_lattice_sweep`` for every sweep point
    and every configured lattice size.

    Lattice sizes come from the module section
    ``vol4_wilson_loop_pipeline_lattice_sweep.lattice_sizes``, falling back to
    the global ``L_values``. Sweep points come from ``sweep_iter(cfg)`` when
    ``b_values``, ``k_values``, ``n0_values`` and ``L_values`` are all
    non-empty; otherwise ``b = k = n0 = 0`` is used for each ``L`` in
    ``L_values``. If ``L_values`` is empty too, a single point with no ``L``
    is run and each lattice size is used as its own ``L``.

    Raises:
        ValueError: if no lattice sizes are configured.
    """
    parser = argparse.ArgumentParser(description="Run the lattice-sweep Wilson-loop simulation.")
    parser.add_argument("--config", default="configs/default.yaml", help="Path to the YAML configuration file.")
    parser.add_argument("--output-dir", default="data/results/vol4_wilson_loop_pipeline_lattice_sweep",
                        help="Directory to write CSV results.")
    args = parser.parse_args()

    cfg = load_config(args.config)

    # Per-module lattice sizes (or fallback to global L_values)
    mod_cfg = cfg.get("vol4_wilson_loop_pipeline_lattice_sweep", {})
    lattice_sizes = list(mod_cfg.get("lattice_sizes", [])) or list(cfg.get("L_values", []))
    if not lattice_sizes:
        raise ValueError("No lattice_sizes configured (module or global).")

    os.makedirs(args.output_dir, exist_ok=True)

    # Preflight log
    cross_cfg = cfg.get("crossover_analysis", {})
    print("[lattice-sweep] config     :", args.config)
    print("[lattice-sweep] out dir    :", args.output_dir)
    print("[lattice-sweep] gauges     :", cross_cfg.get("gauge_groups", ["U1"]))
    print("[lattice-sweep] loop sizes :", cross_cfg.get("loop_sizes", []))
    print("[lattice-sweep] L_values   :", cfg.get("L_values", []))
    print("[lattice-sweep] lattice_sz :", lattice_sizes)
    print("[lattice-sweep] flip tmpl  :", cfg.get("flip_counts_path_template"))
    print("[lattice-sweep] kernel tmpl:", cross_cfg.get("kernel_path_template") or cfg.get("kernel_path_template"))

    # Iterate sweep
    sweep_keys = ("b_values", "k_values", "n0_values", "L_values")
    l_values = list(cfg.get("L_values") or [])
    if all(cfg.get(key) for key in sweep_keys):
        iterator = sweep_iter(cfg)  # (b,k,n0,L)
    elif l_values:
        iterator = [(0.0, 0.0, 0.0, L) for L in l_values]
    else:
        # No global L_values (lattice sizes came from the module section).
        # Yield one point without an L so the `L or lattice_size` fallback
        # below applies, instead of silently running zero jobs.
        iterator = [(0.0, 0.0, 0.0, None)]

    planned = 0
    for (b, k, n0, L) in iterator:
        seed_all(b or 0.0, k or 0.0, n0 or 0.0, L or 0)
        for lattice_size in lattice_sizes:
            run_lattice_sweep(cfg, float(b or 0.0), float(k or 0.0), float(n0 or 0.0),
                              int(L or lattice_size), int(lattice_size), args.output_dir, args.config)
            planned += 1

    print(f"[lattice-sweep] completed {planned} jobs.")


# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
